#!/usr/bin/env python3
import os, json, shutil
import numpy as np, pandas as pd

# File locations used by main(); all are relative to the current working directory.
STACKS = "data/prestacked_stacks.csv"  # per-stack gamma_t(b) rows, randoms already subtracted
PLATEAU = "outputs/lensing_plateau.csv"  # per-stack plateau table; main() rewrites it in place
WINDOWS = "outputs/windows.json"  # per-stack windows, as row positions or a b range

def load_windows(path):
    """Load per-stack fit windows from a JSON file.

    Two top-level layouts are accepted:

    * a dict ``{stack_id: {...}, ...}``. The keys are JSON object keys, so
      they are always strings.
    * a list ``[{"stack_id" or "id": ..., ...}, ...]``.

    Each entry describes its window in one of two ways:

    * by row positions, with ``i0``/``i1``;
    * by a range in ``b``, with ``b_min``/``b_max`` (also spelled
      ``bmin``/``bmax``).

    Row positions win when both are present. Entries are skipped when they
    are not JSON objects, lack an id (list layout), or give no complete
    window.

    Args:
        path: path to the JSON file.

    Returns:
        dict mapping each stack id to either ``("idx", i0, i1)`` (ints) or
        ``("b", b_min, b_max)`` (floats).
    """
    with open(path, "r", encoding="utf-8") as fh:
        W = json.load(fh)

    def _first(v, *keys):
        # Return the first value that is present and not None. Deliberately
        # not `a or b`: 0 / 0.0 are valid bounds and must not count as missing.
        for k in keys:
            val = v.get(k)
            if val is not None:
                return val
        return None

    def _parse(v):
        # Turn one window entry into a ("idx", ...) / ("b", ...) tuple, or None.
        i0, i1 = v.get("i0"), v.get("i1")
        if i0 is not None and i1 is not None:
            return ("idx", int(i0), int(i1))
        bmin = _first(v, "b_min", "bmin")
        bmax = _first(v, "b_max", "bmax")
        if bmin is not None and bmax is not None:
            return ("b", float(bmin), float(bmax))
        return None

    if isinstance(W, dict):
        entries = list(W.items())
    elif isinstance(W, list):
        entries = []
        for v in W:
            if not isinstance(v, dict):
                continue
            # Only None / "" count as "no id"; a numeric id of 0 is valid.
            sid = v.get("stack_id")
            if sid is None or sid == "":
                sid = v.get("id")
            if sid is None or sid == "":
                continue
            entries.append((sid, v))
    else:
        entries = []

    out = {}
    for sid, v in entries:
        if not isinstance(v, dict):
            continue
        spec = _parse(v)
        if spec is not None:
            out[sid] = spec
    return out

def _window_median(g, spec):
    """Return the NaN-ignoring median of ``g["P"]`` inside window *spec* (NaN if empty).

    *spec* is ``("idx", i0, i1)`` (row positions, end exclusive) or
    ``("b", b_min, b_max)`` (inclusive range on the ``b`` column).
    """
    mode, lo, hi = spec
    if mode == "idx":
        # g comes from a frame sorted by (stack_id, b), so positions count rows in b order.
        p = g["P"].to_numpy()
        vec = p[max(0, int(lo)):min(len(p), int(hi))]
    else:
        sel = (g["b"] >= float(lo)) & (g["b"] <= float(hi))
        vec = g.loc[sel, "P"].to_numpy()
    # Same result as np.nanmedian, without its RuntimeWarning on all-NaN windows.
    vec = vec[pd.notna(vec)]
    return float(np.median(vec)) if vec.size else np.nan


def _unsmoothed_amplitudes(stacks, windows):
    """Return a DataFrame (stack_id, A_theta_unsmoothed) with one row per stack in *stacks*."""
    recs = []
    for sid, g in stacks.groupby("stack_id"):
        # JSON object keys are always strings, while pandas parses numeric
        # stack ids as ints; fall back to the string form so those still match.
        spec = windows.get(sid)
        if spec is None:
            spec = windows.get(str(sid))
        recs.append((sid, np.nan if spec is None else _window_median(g, spec)))
    return pd.DataFrame(recs, columns=["stack_id", "A_theta_unsmoothed"])


def main():
    """Replace A_theta in PLATEAU with an unsmoothed, windowed estimate.

    For every stack in STACKS, P = gamma_t * b is computed. The new A_theta
    is the median of P inside that stack's window from WINDOWS.

    PLATEAU is rewritten in place:

    * the previous values are kept in column ``A_theta_smoothed``;
    * the per-stack estimate is stored in ``A_theta_unsmoothed``;
    * a file copy ``*_smoothed_backup.csv`` is kept next to PLATEAU.

    Stacks without a window, or with no data inside it, keep the smoothed
    value. Safe to run repeatedly.

    Raises:
        SystemExit: if an input file or a required column is missing, or no
            window could be parsed.
    """
    missing = [p for p in (STACKS, PLATEAU, WINDOWS) if not os.path.exists(p)]
    if missing:
        raise SystemExit(
            "Missing one of required files: stacks/plateau/windows. "
            f"({', '.join(missing)})"
        )
    win = load_windows(WINDOWS)
    if not win:
        raise SystemExit(f"No windows parsed from {WINDOWS}")

    # Validate before sorting so a missing column yields this message, not a KeyError.
    S = pd.read_csv(STACKS)
    if not {"stack_id", "b", "gamma_t"}.issubset(S.columns):
        raise SystemExit("prestacked_stacks.csv missing required columns")
    S = S.sort_values(["stack_id", "b"])
    S["P"] = S["gamma_t"] * S["b"]

    U = _unsmoothed_amplitudes(S, win)

    P = pd.read_csv(PLATEAU)
    if not {"stack_id", "A_theta"}.issubset(P.columns):
        raise SystemExit(f"{PLATEAU} missing required columns (stack_id, A_theta)")

    backup = os.path.splitext(PLATEAU)[0] + "_smoothed_backup.csv"
    # Once A_theta_smoothed exists, PLATEAU has already been processed and
    # its A_theta is no longer smoothed. Overwriting the backup then would
    # destroy the only pristine copy.
    if "A_theta_smoothed" not in P.columns or not os.path.exists(backup):
        shutil.copy2(PLATEAU, backup)
        backup_msg = f"backup saved to {backup}"
    else:
        backup_msg = f"existing backup kept at {backup}"

    # A previous run leaves an A_theta_unsmoothed column behind. Drop it so
    # the merge does not produce _x/_y suffixed columns.
    P = P.drop(columns=["A_theta_unsmoothed"], errors="ignore")
    Q = P.merge(U, on="stack_id", how="left")
    if "A_theta_smoothed" not in Q.columns:
        Q["A_theta_smoothed"] = Q["A_theta"]
    m = Q["A_theta_unsmoothed"].notna()
    # Derive from the smoothed column so re-runs reset stacks that lost their window.
    Q["A_theta"] = Q["A_theta_unsmoothed"].where(m, Q["A_theta_smoothed"])

    Q.to_csv(PLATEAU, index=False)
    print(f"Updated {PLATEAU} with unsmoothed A_theta; {backup_msg}.")
    n_nowin = int(U["A_theta_unsmoothed"].isna().sum())
    if n_nowin:
        print(f"Note: {n_nowin} stack(s) had no usable window/data; A_theta left smoothed.")
    # quick sanity print
    print(Q.loc[m, ["stack_id", "A_theta_smoothed", "A_theta"]].head(6).to_string(index=False))


if __name__ == "__main__":
    main()
